#include <utils/err.h>
#include <seminix/cache.h>
#include <seminix/syscall.h>
#include <seminix/param.h>
#include <cap/rlimit.h>
#include <libseminix/types.h>

/* Statically initialised per-capability-type limit table used by the root rlimit. */
static caprlimit_t root_limit[cap_count_max] = INIT_ROOT_RLIMIT;

/*
 * The root rlimit capability.  Unlike rlimits created at runtime, its
 * limit table is the static root_limit array rather than a heap allocation,
 * and its limit_used counter starts at 0.
 */
cap_rlimit_t root_cap_rlimit = {
    .crlim = root_limit,
    .limit_used = 0,
    .cap = INIT_ROOT_CAP,
};

/*
 * Create a capability for @obj as a child of @cap_rlimit's capability.
 *
 * The amount reported by cap_create_rlimit(@obj) is first checked with
 * rlimit_overflow() against @cap_rlimit's limit entry for the object's
 * capability type.  Nothing is created if that check fails.  On success the
 * new capability is bound to @cap_rlimit (cap_set_rlimit()) and rlimit_up()
 * is called with the same amount.
 *
 * Returns the new capability, ERR_PTR(-SERRNO_EOVERFLOW), or whatever error
 * cap_create() produced.
 */
static cap_t *do_createrlimit(cap_rlimit_t *cap_rlimit, seminix_object_t *obj, int rights)
{
    struct caprlimit *type_limit;
    unsigned long cost;
    cap_t *cap;

    cost = cap_create_rlimit(obj);
    type_limit = &cap_rlimit->crlim[object_to_cap_type(obj->type)];

    if (rlimit_overflow(cost, type_limit))
        return ERR_PTR(-SERRNO_EOVERFLOW);

    cap = cap_create(&cap_rlimit->cap, rights, obj);
    if (!IS_ERR(cap)) {
        /* Only account for capabilities that were actually created. */
        cap_set_rlimit(cap, cap_rlimit);
        rlimit_up(cap, cost);
    }

    return cap;
}

/*
 * Boot-time (__init) wrapper around do_createrlimit(): creates a capability
 * for @obj under @cap_rlimit with @rights.  Returns the new capability or an
 * ERR_PTR.
 */
cap_t __init *create_rlimit_obj(cap_rlimit_t *cap_rlimit, seminix_object_t *obj, int rights)
{
    cap_t *created = do_createrlimit(cap_rlimit, obj, rights);

    return created;
}

/*
 * createrlimit(cnode, rights, object) - create a capability for the object
 * described by the user-space @object, charged against the calling task's
 * rlimit capability, and install it in an unused slot of @cnode.
 *
 * Returns the result of capdesc_set_desc() once the capability has been
 * inserted, or a negative SERRNO code on failure.
 *
 * The reference obtained by rlimit_get() and the cnode obtained by
 * cnode_get() are released on every exit path after they are acquired.
 */
SYSCALL_DEFINE3(createrlimit, int, cnode, int, rights, seminix_object_t __user *, object)
{
    int cap_type, index, ret;
    cap_t *new_cap;
    seminix_object_t obj;
    cap_rlimit_t *cap_rlimit;
    cap_cnode_t *cap_cnode;

    if (!object)
        return -SERRNO_EINVAL;

    cap_rlimit = rlimit_get();
    if (!cap_rlimit)
        return -SERRNO_EILLEGAL;

    /*
     * From here on we hold a reference on cap_rlimit, so every failure must
     * unwind through out_rlimit to drop it.
     */
    if (copy_from_user(&obj, object, sizeof (seminix_object_t))) {
        ret = -SERRNO_EFAULT;
        goto out_rlimit;
    }

    if (seminix_object_type_invalid(obj.type)) {
        ret = -SERRNO_EINVAL;
        goto out_rlimit;
    }

    cap_type = object_to_cap_type(obj.type);
    if (cap_type == cap_null_cap) {
        ret = -SERRNO_EILLEGAL;
        goto out_rlimit;
    }

    cap_cnode = cnode_get(cnode);
    if (IS_ERR(cap_cnode)) {
        ret = PTR_ERR(cap_cnode);
        goto out_rlimit;
    }

    index = cnode_get_unused_slot(cap_cnode);
    if (index < 0) {
        ret = index;
        goto out_cnode;
    }

    new_cap = do_createrlimit(cap_rlimit, &obj, rights);
    if (IS_ERR(new_cap)) {
        ret = PTR_ERR(new_cap);
        goto out_cnode;
    }

    cnode_cap_insert_slot(cap_cnode, index, new_cap);
    ret = capdesc_set_desc(cnode, index);

    /* Release in reverse order of acquisition. */
out_cnode:
    cnode_put(cap_cnode);
out_rlimit:
    rlimit_put(cap_rlimit);
    return ret;
}

/*
 * Move @limit units of the @cap_type budget between @parent and @child.
 * cap_is_parent() must report @parent as @child's parent.
 *
 * up == true:  parent->rlim_cur += @limit and child->rlim_max += @limit.
 *              Fails if either result would exceed parent->rlim_max.
 * up == false: parent->rlim_cur -= @limit and child->rlim_max -= @limit.
 *              Fails if parent->rlim_cur < @limit, or if child->rlim_max
 *              would drop below child->rlim_cur.
 *
 * @limit comes straight from user space, so every bound is checked with a
 * subtraction whose operands have already been validated.  A naive
 * "limit + x > max" or "max - limit < cur" can wrap the unsigned arithmetic
 * and let a huge @limit slip through and corrupt the counters.
 *
 * Returns 0 on success, -SERRNO_EILLEGAL if @parent is not @child's parent,
 * or -SERRNO_EOVERFLOW if the adjustment would violate a limit.  On failure
 * nothing is modified.
 */
static int do_setrlimit(cap_rlimit_t *parent, cap_rlimit_t *child,
    int cap_type, unsigned long limit, bool up)
{
    struct caprlimit *src_crlim, *dst_crlim;

    src_crlim = &parent->crlim[cap_type];
    dst_crlim = &child->crlim[cap_type];

    if (!cap_is_parent(&parent->cap, &child->cap))
        return -SERRNO_EILLEGAL;

    if (up) {
        /*
         * "limit + x > max" rewritten as "x > max || limit > max - x" so the
         * test cannot be defeated by wrap-around.
         */
        if (src_crlim->rlim_cur > src_crlim->rlim_max ||
            limit > src_crlim->rlim_max - src_crlim->rlim_cur ||
            dst_crlim->rlim_max > src_crlim->rlim_max ||
            limit > src_crlim->rlim_max - dst_crlim->rlim_max)
            return -SERRNO_EOVERFLOW;

        src_crlim->rlim_cur += limit;
        dst_crlim->rlim_max += limit;
    } else {
        /* Ensure limit <= rlim_max before subtracting it. */
        if (src_crlim->rlim_cur < limit ||
            dst_crlim->rlim_max < limit ||
            dst_crlim->rlim_max - limit < dst_crlim->rlim_cur)
            return -SERRNO_EOVERFLOW;

        src_crlim->rlim_cur -= limit;
        dst_crlim->rlim_max -= limit;
    }

    return 0;
}

/*
 * setrlimit(rlimit, cap_type, limit, up) - adjust the @cap_type budget of the
 * rlimit capability looked up at @rlimit.  The budget moves between that
 * capability and the calling task's own rlimit capability, in the direction
 * given by @up, via do_setrlimit().
 *
 * Returns 0 or a negative SERRNO code.  The looked-up capability and the
 * caller's rlimit reference are both released before returning.
 */
SYSCALL_DEFINE4(setrlimit, int, rlimit, int, cap_type, unsigned long, limit, bool, up)
{
    cap_rlimit_t *self_rlimit;
    cap_t *target;
    int err;

    if (seminix_cap_type_invalid(cap_type))
        return -SERRNO_EINVAL;

    target = cnode_capget(rlimit, cap_rlimit_cap);
    if (IS_ERR(target))
        return PTR_ERR(target);

    self_rlimit = rlimit_get();
    if (self_rlimit) {
        err = do_setrlimit(self_rlimit, (cap_rlimit_t *)target, cap_type, limit, up);
        rlimit_put(self_rlimit);
    } else {
        err = -SERRNO_EILLEGAL;
    }

    cnode_capput(target);
    return err;
}

/*
 * getrlimit(rlimit, cap_type, crlim) - copy the @cap_type limit entry of the
 * rlimit capability looked up at @rlimit out to the user buffer @crlim.
 *
 * Returns 0, -SERRNO_EINVAL for a NULL buffer or invalid type,
 * -SERRNO_EFAULT if the copy fails, or the lookup error.
 */
SYSCALL_DEFINE3(getrlimit, int, rlimit, int, cap_type, struct caprlimit __user *, crlim)
{
    struct caprlimit *entry;
    cap_t *cap;
    int err = 0;

    if (!crlim)
        return -SERRNO_EINVAL;
    if (seminix_cap_type_invalid(cap_type))
        return -SERRNO_EINVAL;

    cap = cnode_capget(rlimit, cap_rlimit_cap);
    if (IS_ERR(cap))
        return PTR_ERR(cap);

    entry = &((cap_rlimit_t *)cap)->crlim[cap_type];
    if (copy_to_user(crlim, entry, sizeof (struct caprlimit)))
        err = -SERRNO_EFAULT;

    cnode_capput(cap);
    return err;
}

cap_rlimit_t *rlimit_get(void)
{
    if (current->cap_rlimit) {
        capget(&current->cap_rlimit->cap);
        return current->cap_rlimit;
    }
    return NULL;
}

/* Slab cache from which rlimit_create_cap() allocates struct cap_rlimit. */
static struct kmem_cache *cap_rlimit_cache;

/*
 * Create cap_rlimit_cache during boot.  The cache is created with
 * SLAB_PANIC; as in Linux this is presumably fatal on failure, which is why
 * no error is checked or returned here.
 */
static __init int rlimit_cap_init(void)
{
    cap_rlimit_cache = KMEM_CACHE(cap_rlimit, SLAB_PANIC);
    return 0;
}
userver_initcall(rlimit_cap_init)

static struct cap_rlimit *rlimit_create_cap(void)
{
    struct cap_rlimit *new_rlimit;

    new_rlimit = kmem_cache_alloc(cap_rlimit_cache, GFP_KERNEL | GFP_ZERO);
    if (!new_rlimit)
        return ERR_PTR(-ENOMEM);

    return new_rlimit;
}

/*
 * cap_ops->cap_create hook.  Allocates a new rlimit capability together with
 * its own table of cap_count_max per-type limit entries (kcalloc).
 * @object is not used.
 *
 * Returns the embedded cap_t, or an ERR_PTR.  If the table allocation fails,
 * the cap_rlimit is returned to cap_rlimit_cache first.
 */
static cap_t *rlimit_cap_create(seminix_object_t *object)
{
    struct cap_rlimit *rl;
    struct caprlimit *table;

    rl = rlimit_create_cap();
    if (IS_ERR(rl))
        return (cap_t *)rl;

    table = kcalloc(cap_count_max, sizeof (struct caprlimit), GFP_KERNEL);
    if (table == NULL) {
        kmem_cache_free(cap_rlimit_cache, rl);
        return ERR_PTR(-SERRNO_ENOMEM);
    }

    rl->crlim = table;
    rl->limit_used = 0;

    return &rl->cap;
}

/*
 * cap_ops->cap_dup hook.  Allocates a new rlimit capability whose crlim
 * pointer is copied from @cap, so both refer to the same limit table.  All
 * other fields are left as rlimit_create_cap() produced them (zero-filled).
 *
 * Returns the new capability or an ERR_PTR.
 */
static cap_t *rlimit_cap_dup(cap_t *cap)
{
    struct cap_rlimit *src = (struct cap_rlimit *)cap;
    struct cap_rlimit *copy;

    assert(cap_get_cap_type(cap) == cap_rlimit_cap);

    copy = rlimit_create_cap();
    if (IS_ERR(copy))
        return (cap_t *)copy;

    copy->crlim = src->crlim;
    return &copy->cap;
}
/*
 * Compiled out.  Apparently meant as a prepare-revoke/prepare-delete check
 * for rlimit capabilities: it would return -SERRNO_EBUSY while the rlimit's
 * limit_used counter is non-zero, and 0 otherwise.
 */
#if 0
static int rlimit_cap_prepare_remove(cap_t *cap)
{
    cap_rlimit_t *cap_rlimit = (cap_rlimit_t *)cap;

    assert(cap_get_cap_type(cap) == cap_rlimit_cap);

    if (cap_rlimit->limit_used)
        return -SERRNO_EBUSY;
    return 0;
}
#endif
/*
 * cap_ops->cap_revoke hook.  Returns only the cap_rlimit object to
 * cap_rlimit_cache; the crlim table it points to is not freed here.
 */
static void rlimit_cap_revoke(cap_t *cap)
{
    struct cap_rlimit *rl = (struct cap_rlimit *)cap;

    assert(cap_get_cap_type(cap) == cap_rlimit_cap);
    kmem_cache_free(cap_rlimit_cache, rl);
}

/*
 * cap_ops->cap_delete hook.  Frees the per-type limit table and then the
 * cap_rlimit object itself.
 *
 * The two allocations come from different allocators and must be released
 * accordingly:
 * - the table was allocated with kcalloc() in rlimit_cap_create(), so it goes
 *   to kfree();
 * - the object came from cap_rlimit_cache in rlimit_create_cap(), so it goes
 *   to kmem_cache_free().
 * Passing the table to kmem_cache_free() as well would double-free it and
 * leak the object.
 *
 * NOTE(review): rlimit_cap_dup() copies the crlim pointer, so a duplicated
 * cap shares this table.  Confirm the cap core never calls cap_delete on more
 * than one cap sharing a table.
 */
static void rlimit_cap_delete(cap_t *cap)
{
    cap_rlimit_t *cap_rlimit = (cap_rlimit_t *)cap;

    assert(cap_get_cap_type(cap) == cap_rlimit_cap);

    kfree(cap_rlimit->crlim);
    kmem_cache_free(cap_rlimit_cache, cap_rlimit);
}

/*
 * Capability operation table for rlimit capabilities.  The prepare hooks are
 * left unset; rlimit_cap_prepare_remove is compiled out with #if 0.
 */
const struct cap_ops rlimit_cap_ops __ro_after_init = {
    .cap_create = rlimit_cap_create,
    .cap_dup    = rlimit_cap_dup,
    .cap_revoke = rlimit_cap_revoke,
    .cap_delete = rlimit_cap_delete,
    /* .cap_prepare_revoke = rlimit_cap_prepare_remove, */
    /* .cap_prepare_delete = rlimit_cap_prepare_remove, */
};
